JAX and Functional ML
Reading time: ~45 minutes | Level: Advanced
The Puzzle
Before reading further, predict what this code prints:
import jax
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
def f(x):
    print("tracing f")  # <-- what happens here?
    return jnp.sum(x ** 2)
f_jit = jax.jit(f)
result1 = f_jit(x)
result2 = f_jit(x)
result3 = f_jit(jnp.array([4.0, 5.0, 6.0]))
The output is:
tracing f
Printed exactly once, not three times. The second and third calls reuse the compiled version without re-running the Python body of f. The third call passes new values, but they have the same shape and dtype. This is the entire point of JAX's execution model.
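jax.jit(f) only wraps f. On the first call to f_jit, JAX runs f once with abstract tracers instead of real values. It records every operation into a jaxpr (JAX's intermediate representation) and lowers that to an XLA computation. It then compiles the computation to native code for the CPU, GPU, or TPU. Finally, it caches the result keyed by the shapes and dtypes of the inputs, plus the values of any static arguments. Subsequent calls with matching shapes and dtypes go directly to the compiled executable without executing the Python body of f, leaving only a small dispatch cost.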
Understanding this tracing model is the key to understanding every JAX design decision, restriction, and performance characteristic.
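The sketch below shows what the cache key contains. It uses a hypothetical trace_log list, appended to only while JAX traces f. Assuming JAX's default 32-bit settings, a new shape or a new dtype triggers a fresh trace, while new values with the same shape and dtype do not.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: appended to only while JAX traces f

def f(x):
    # Python side effects run at trace time, so each entry marks one (re)trace.
    trace_log.append((x.shape, str(x.dtype)))
    return jnp.sum(x ** 2)

f_jit = jax.jit(f)
f_jit(jnp.array([1.0, 2.0, 3.0]))  # first call: traces for float32[3]
f_jit(jnp.array([4.0, 5.0, 6.0]))  # same shape and dtype: cache hit, no trace
f_jit(jnp.array([1.0, 2.0]))       # new shape float32[2]: retrace
f_jit(jnp.array([1, 2, 3]))        # new dtype int32: retrace

print(len(trace_log))  # 3: one trace per distinct (shape, dtype) seen
```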
Why This Matters
JAX was designed at Google Brain for research that needs:
- Hardware-agnostic vectorisation: write a function for one sample; vmap makes it a batched function with no code changes
- Higher-order derivatives: grad(grad(f)) just works; Hessians and Jacobians are first-class
- XLA compilation: computation graphs are compiled to highly optimised native code for CPU, GPU, and TPU
- Functional purity: side-effect-free functions compose cleanly for distributed and parallelised training
JAX is the foundation of Google DeepMind's research stack, and it powers Flax, Haiku, Optax, and Equinox -- the libraries behind some of the most important open-source models of the last few years. Understanding JAX makes you a better ML engineer even if you primarily use PyTorch, because the functional thinking it demands clarifies what autograd actually does.
1. The Core Abstraction: Functional Transformations
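JAX's four fundamental transformations:
- jit: traces a function once and compiles it with XLA (section 2)
- grad: returns a new function that computes the gradient of a scalar-valued function (section 3)
- vmap: turns a function written for one example into one that operates on a batch (section 4)
- pmap: runs a function in parallel across multiple devices (section 5)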
These transformations are composable. jax.jit(jax.vmap(jax.grad(f))) is valid: differentiate f, vectorise the gradient function, then JIT-compile the batched gradient computation. This composability is the heart of JAX. Per-example gradients, batched Hessians, or a compiled meta-learning step each become a one-line composition rather than hand-written batching and derivative code.
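Here is a small sketch of that composition. It assumes a toy function f(w, x) = sum(tanh(w * x)), chosen only because its gradient is easy to check by hand.

```python
import jax
import jax.numpy as jnp

def f(w, x):
    # Scalar-valued function of parameters w and a single example x
    return jnp.sum(jnp.tanh(w * x))

# grad: df/dw for one example; vmap: map over examples (axis 0 of x) while
# sharing w; jit: compile the whole batched-gradient computation.
per_example_grads = jax.jit(jax.vmap(jax.grad(f), in_axes=(None, 0)))

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features each

g = per_example_grads(w, xs)
print(g.shape)  # (4, 3): one gradient per example

# Analytic check: d/dw tanh(w * x) = x * (1 - tanh(w * x)^2)
expected = xs * (1 - jnp.tanh(w * xs) ** 2)
print(jnp.allclose(g, expected, atol=1e-5))  # True
```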
2. jit: XLA Compilation and Tracing
import jax
import jax.numpy as jnp
import time
# A numerically intensive function
def softmax_cross_entropy(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Cross-entropy loss via log-softmax."""
    log_probs = logits - jax.nn.logsumexp(logits, axis=-1, keepdims=True)
    return -jnp.mean(jnp.sum(labels * log_probs, axis=-1))
# JIT-compile it
fast_loss = jax.jit(softmax_cross_entropy)
# Benchmark: interpreted vs compiled
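key = jax.random.PRNGKey(0)
logits_key, labels_key = jax.random.split(key)  # separate keys so logits and labels are independent
logits = jax.random.normal(logits_key, (512, 1000))  # batch=512, classes=1000
labels = jax.nn.one_hot(jax.random.randint(labels_key, (512,), 0, 1000), 1000)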
# Warmup (triggers compilation on first call)
_ = fast_loss(logits, labels).block_until_ready() # block_until_ready() ensures async op is done
t0 = time.perf_counter()
for _ in range(100):
    result = fast_loss(logits, labels).block_until_ready()
print(f"JIT: {(time.perf_counter() - t0)*1000:.2f} ms / 100 calls")
t0 = time.perf_counter()
for _ in range(100):
    result = softmax_cross_entropy(logits, labels).block_until_ready()
print(f"No JIT: {(time.perf_counter() - t0)*1000:.2f} ms / 100 calls")
# The speedup depends on the hardware and on the function. It comes mainly from XLA
# fusing several elementwise ops into fewer kernels, and from replacing one Python
# dispatch per operation with a single dispatch per call. Measure on your own setup.
What the Tracer Sees
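During tracing, JAX replaces concrete values with tracers that carry only an abstract value: a ShapedArray, which holds shape and dtype but no data. Every JAX operation applied to a tracer records itself as an equation in a jaxpr, JAX's intermediate representation. jit then lowers the jaxpr to XLA HLO (High-Level Operations) for compilation.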
import jax
import jax.numpy as jnp
from jax import make_jaxpr
def f(x):
    return jnp.sin(x) + jnp.cos(x)
# Inspect the jaxpr that jit would lower to XLA and compile
print(make_jaxpr(f)(jnp.ones(4)))
# { lambda ; a:f32[4]. let
# b:f32[4] = sin a
# c:f32[4] = cos a
# d:f32[4] = add b c
# in (d,) }
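The sketch below makes the tracer and each stage of the pipeline visible. It records what f receives during tracing and prints the jaxpr. It then asks jit for the lowered and compiled forms via .lower(...) and .compile(). The observed list is a hypothetical helper, and the exact text printed varies by JAX version.

```python
import jax
import jax.numpy as jnp

observed = []  # hypothetical helper: filled only while f is being traced

def f(x):
    # Here x is a tracer: its shape and dtype are known, its values are not.
    observed.append((x.shape, str(x.dtype)))
    return jnp.sin(x) + jnp.cos(x)

x = jnp.ones(4)

# Stage 1: tracing produces a jaxpr (sin, cos and add equations on f32[4])
closed_jaxpr = jax.make_jaxpr(f)(x)
print(closed_jaxpr)
print(observed[0])  # ((4,), 'float32')

# Stage 2: jit lowers the traced computation to XLA's input format...
lowered = jax.jit(f).lower(x)
print(lowered.as_text()[:400])  # StableHLO/HLO text; exact form varies by version

# Stage 3: ...and compiles it to an executable for the current backend
compiled = lowered.compile()
print(compiled(x))
```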
Static vs Dynamic Arguments
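The tracer sees shapes and dtypes but not values. If your function uses Python control flow (if, while) on the VALUE of a traced input, tracing fails, because there is no concrete value to branch on. You have two options. Either make the argument static, so its value is baked into the compiled function and each new value triggers a recompile, or express the branch with a JAX control-flow primitive. This is a fundamental constraint: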
# BAD: this Python conditional branches on the value of a JAX array
def bad_fn(x, condition):
    if condition > 0:  # condition is a JAX array; its value is abstract during tracing
        return x * 2
    else:
        return x / 2
# jax.jit(bad_fn)(x, jnp.array(1.0)) -- raises ConcretizationTypeError
# (recent JAX versions raise its subclass TracerBoolConversionError)
# GOOD option 1: use jax.lax.cond (a differentiable conditional)
def good_fn(x, condition):
    return jax.lax.cond(
        condition > 0,
        lambda x: x * 2,  # branch if True
        lambda x: x / 2,  # branch if False
        x,
    )
# GOOD option 2: mark condition as static so it's baked into the compiled function
from functools import partial
@partial(jax.jit, static_argnums=(1,)) # argument 1 (condition) is static
def fn_with_static(x, condition: bool):
    if condition:
        return x * 2
    return x / 2
result = fn_with_static(jnp.ones(4), True) # compiled with condition=True
# If you call fn_with_static(x, False), JAX will retrace and compile a second version
Rule of thumb: any value that drives Python control flow (loops, conditionals) must either be static (declared via static_argnums or static_argnames) or be replaced with JAX control-flow primitives (lax.cond, lax.while_loop, lax.fori_loop).
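Here is a minimal sketch of the three primitives named above. The functions scale, sum_of_squares_below and halvings_until_below_one are hypothetical examples. Each compiles under jax.jit even though its control flow depends on traced values.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def scale(x, condition):
    # lax.cond: both branches are traced; the predicate is evaluated at runtime
    return lax.cond(condition > 0, lambda v: v * 2, lambda v: v / 2, x)

@jax.jit
def sum_of_squares_below(n):
    # lax.fori_loop with a traced bound n: sums i*i for i in [0, n)
    return lax.fori_loop(0, n, lambda i, acc: acc + i * i, jnp.int32(0))

@jax.jit
def halvings_until_below_one(x):
    # lax.while_loop: the number of iterations depends on the runtime value of x
    def cond_fun(state):
        value, _ = state
        return value >= 1.0

    def body_fun(state):
        value, count = state
        return value / 2.0, count + 1

    _, count = lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
    return count

print(scale(jnp.array(3.0), jnp.array(1.0)))      # 6.0
print(scale(jnp.array(3.0), jnp.array(-1.0)))     # 1.5
print(sum_of_squares_below(4))                    # 14  (0 + 1 + 4 + 9)
print(halvings_until_below_one(jnp.array(10.0)))  # 4   (10 -> 5 -> 2.5 -> 1.25 -> 0.625)
```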
3. grad: Automatic Differentiation
import jax
import jax.numpy as jnp
# grad computes the gradient of a scalar-valued function
def quadratic(x: jnp.ndarray) -> float:
    """f(x) = sum(x^2)"""
    return jnp.sum(x ** 2)
grad_f = jax.grad(quadratic) # df/dx = 2x
x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x)) # [2. 4. 6.]
# grad differentiates with respect to the FIRST argument by default.
# Use argnums to differentiate with respect to a different argument.
def loss(params, X, y):
    """Simple linear regression loss."""
    W, b = params
    pred = X @ W + b
    return jnp.mean((pred - y) ** 2)
W = jnp.ones((3, 1))
b = jnp.zeros(1)
X = jnp.array([[1., 2., 3.], [4., 5., 6.]])
y = jnp.array([[1.], [2.]])
# Differentiate with respect to params (argument 0)
grad_loss = jax.grad(loss, argnums=0)
grads = grad_loss((W, b), X, y)
print(grads[0].shape, grads[1].shape) # (3,1) (1,)
# value_and_grad: returns (value, gradient) in one pass
# More efficient than computing both separately
value_and_grad_fn = jax.value_and_grad(loss, argnums=0)
loss_val, grads = value_and_grad_fn((W, b), X, y)
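A quick way to trust jax.grad is to check it against a closed form. For the mean squared error above, take residuals r = XW + b - y over n entries. Then dL/dW = (2/n) X^T r and dL/db = (2/n) sum(r). The sketch below reuses the same toy data.

```python
import jax
import jax.numpy as jnp

def loss(params, X, y):
    W, b = params
    pred = X @ W + b
    return jnp.mean((pred - y) ** 2)

W = jnp.ones((3, 1))
b = jnp.zeros(1)
X = jnp.array([[1., 2., 3.], [4., 5., 6.]])
y = jnp.array([[1.], [2.]])

# Gradients come back with the same structure as params: a (W, b) tuple
loss_val, (gW, gb) = jax.value_and_grad(loss)((W, b), X, y)

# Closed form for the mean of n squared residuals r = XW + b - y:
#   dL/dW = (2/n) X^T r,   dL/db = (2/n) * (sum of r over rows)
resid = X @ W + b - y
n = resid.size
gW_manual = 2.0 / n * X.T @ resid
gb_manual = 2.0 / n * jnp.sum(resid, axis=0)

print(loss_val)        # 97.0
print(gW.ravel(), gb)  # [57. 75. 93.] [18.]
print(jnp.allclose(gW, gW_manual), jnp.allclose(gb, gb_manual))  # True True
```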
Higher-Order Derivatives
import jax
import jax.numpy as jnp
def f(x: float) -> float:
    return jnp.sin(x)
# First derivative: cos(x)
df = jax.grad(f)
print(df(jnp.array(0.0))) # ~1.0
# Second derivative: -sin(x)
ddf = jax.grad(jax.grad(f))
print(ddf(jnp.array(0.0))) # ~0.0
# Jacobian: for vector-valued functions, use jax.jacobian
def g(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.array([x[0]**2, x[0]*x[1], x[1]**3])
J = jax.jacobian(g)(jnp.array([2.0, 3.0]))
print(J)  # shape (3, 2) -- 3 outputs, 2 inputs
# [[ 4.  0.]
#  [ 3.  2.]
#  [ 0. 27.]]  -- rows are the gradients of x0^2, x0*x1, x1^3 at (2, 3)
# Hessian: jacobian of the gradient
def loss_scalar(x: jnp.ndarray) -> float:
    return jnp.sum(x ** 4)
H = jax.hessian(loss_scalar)(jnp.array([1.0, 2.0]))
print(H) # shape (2, 2) -- second derivative matrix
# [[12. 0.]
# [ 0. 48.]] -- 12*x[0]^2 = 12, 12*x[1]^2 = 48
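Second-order methods often need only Hessian-vector products, not the full matrix. One common recipe composes jax.jvp with jax.grad ("forward-over-reverse"). The sketch below uses a hypothetical hvp helper and checks it against the dense Hessian above.

```python
import jax
import jax.numpy as jnp

def loss_scalar(x):
    return jnp.sum(x ** 4)

def hvp(f, x, v):
    # Differentiate grad(f) along direction v: computes H(x) @ v
    # without materialising the full Hessian.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])
print(hvp(loss_scalar, x, v))           # [12. 48.]
print(jax.hessian(loss_scalar)(x) @ v)  # same values via the dense Hessian
```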
4. vmap: Auto-Vectorisation
vmap is JAX's most underrated transformation. It lets you write functions for a single example and automatically vectorise them over a batch dimension -- without explicit batch indices in your code.
import jax
import jax.numpy as jnp
# Write the loss for a SINGLE example (no batch dimension)
def loss_single(params, x_single: jnp.ndarray, y_single: float) -> float:
    """MSE loss for one sample."""
    W, b = params
    pred = jnp.dot(W, x_single) + b
    return (pred - y_single) ** 2
# Batched version: apply loss_single over a batch of X, y
# in_axes=(None, 0, 0) means:
# - params is not batched (None)
# - x is batched along axis 0
# - y is batched along axis 0
batched_loss = jax.vmap(loss_single, in_axes=(None, 0, 0))
W = jnp.ones(5)
b = jnp.zeros(())
X_batch = jnp.ones((32, 5)) # 32 samples
y_batch = jnp.ones(32)
losses = batched_loss((W, b), X_batch, y_batch)
print(losses.shape) # (32,) -- one loss per sample
# Mean over the batch
mean_loss = jnp.mean(losses)
# Why vmap instead of explicit broadcasting?
# 1. You write cleaner single-sample logic without tracking batch dims
# 2. vmap rewrites each operation into its batched form (e.g. a dot becomes a matmul),
#    so XLA sees one batched computation rather than a Python loop
# 3. You can compose: jax.vmap(jax.grad(loss_single)) gives per-sample gradients
# Per-sample gradients (useful in differential privacy, influence functions)
grad_single = jax.grad(loss_single) # gradient for one sample
per_sample_grads = jax.vmap(grad_single, in_axes=(None, 0, 0)) # batch it
grads = per_sample_grads((W, b), X_batch, y_batch)
print(grads[0].shape) # (32, 5) -- gradient of W for each of 32 samples
vmap composition with jit:
# Typical usage: vmap for batching, jit for compilation speed
fast_batched_loss = jax.jit(jax.vmap(loss_single, in_axes=(None, 0, 0)))
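# Either order works:
# jit(vmap(f)) -- compile the vectorised function once (the usual idiom)
# vmap(jit(f)) -- vmap batches the inner jitted computation, so it still runs as one
#                 compiled call, but without an outer jit the batching transformation
#                 is re-applied in Python on every call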
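The payoff of writing single-example code shows up most clearly with nested vmap. The sketch below assumes a hypothetical helper sq_dist for one pair of vectors. Two vmaps turn it into an all-pairs distance matrix with no manual broadcasting, and the result is compared against an explicitly broadcast version.

```python
import jax
import jax.numpy as jnp

def sq_dist(a, b):
    # Written for ONE pair of vectors: no batch dimensions anywhere
    return jnp.sum((a - b) ** 2)

# Inner vmap: one a against every row of B. Outer vmap: every row of A.
pairwise_sq_dist = jax.jit(
    jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))
)

A = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
B = jnp.array([[0.0, 0.0], [3.0, 4.0]])
D = pairwise_sq_dist(A, B)
print(D.shape)  # (3, 2): D[i, j] = squared distance between A[i] and B[j]

# The same result with explicit broadcasting, for comparison
D_manual = jnp.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
print(jnp.allclose(D, D_manual))  # True
```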
5. pmap: Multi-Device Parallelism
pmap shards computation across multiple devices (typically several GPUs or TPU cores). Each device runs its own copy of the function, and collective operations (like lax.pmean) aggregate across devices.
import jax
import jax.numpy as jnp
from jax import pmap, lax
# Check available devices
print(jax.local_devices()) # [CpuDevice(id=0)] or multiple GPUs
n_devices = jax.local_device_count()
def train_step(params, X_shard, y_shard):
    """Loss and device-averaged gradients on one device's shard of the data."""
    def loss_fn(params):
        W, b = params
        pred = X_shard @ W + b
        return jnp.mean((pred - y_shard) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Average gradients across all devices
    # lax.pmean is a collective op: each device receives the mean of all devices' values
    grads = jax.tree_util.tree_map(lambda g: lax.pmean(g, axis_name="batch"), grads)
    loss = lax.pmean(loss, axis_name="batch")
    return loss, grads
# pmap the training step, naming the device axis "batch"
parallel_train_step = pmap(train_step, axis_name="batch")
# Replicate params on each device
W = jnp.ones((10, 1))
b = jnp.zeros(1)
params = (W, b)
replicated_params = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape),
    params
)
# Shard data: add a leading axis of size n_devices
X = jnp.ones((n_devices * 64, 10))
y = jnp.ones((n_devices * 64, 1))
X_sharded = X.reshape(n_devices, 64, 10)
y_sharded = y.reshape(n_devices, 64, 1)
loss, grads = parallel_train_step(replicated_params, X_sharded, y_sharded)
# Each element in grads has shape (n_devices, ...) -- same on every device (averaged)
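In practice, most researchers use pmap through higher-level libraries (Flax's flax.training.train_state + jax.lax.psum patterns, or Equinox with eqx.filter_pmap). The above shows the primitives. For new multi-device code, the JAX documentation now recommends jax.jit with sharded arrays, or shard_map, over pmap.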
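You can explore collectives without several accelerators. jax.vmap also accepts an axis_name, and collectives such as lax.psum and lax.pmean then reduce over the mapped axis. The sketch below (the function name is hypothetical) uses this to check what the averaging in train_step computes before running it under pmap.

```python
import jax
import jax.numpy as jnp
from jax import lax

def share_of_total_and_mean(x):
    # Inside the mapped function, x is one simulated "device's" value.
    total = lax.psum(x, axis_name="batch")  # sum over the named axis
    mean = lax.pmean(x, axis_name="batch")  # mean over the named axis
    return x / total, mean

xs = jnp.array([1.0, 2.0, 3.0, 4.0])  # four simulated "devices"
shares, means = jax.vmap(share_of_total_and_mean, axis_name="batch")(xs)
print(shares)  # [0.1 0.2 0.3 0.4]
print(means)   # [2.5 2.5 2.5 2.5]: every position receives the same mean
```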
6. Pure Functions: The Fundamental Constraint
JAX's transformations (jit, grad, vmap, pmap) only work correctly with pure functions: functions with no side effects and no dependence on external mutable state.
import jax
import jax.numpy as jnp
# BAD: mutating Python-level state inside a jitted function
counter = 0
@jax.jit
def bad_fn(x):
    global counter
    counter += 1  # side effect -- this runs only during tracing, not at call time!
    return x * 2
bad_fn(jnp.array(1.0))
bad_fn(jnp.array(1.0))
print(counter) # 1 -- not 2! The Python code ran only during the one tracing call.
# BAD: storing state in a Python variable and reading it in jitted code
weights = jnp.ones(10)
@jax.jit
def bad_fn2(x):
    return x + weights  # weights is captured from Python scope
bad_fn2(jnp.ones(10))  # first call traces: the current weights (ones) are baked in
# The jitted function captured the VALUE at trace time, not a reference.
weights = jnp.zeros(10)  # reassigning weights after the first call has no effect
bad_fn2(jnp.ones(10))  # cache hit: still adds ones, so every element is 2.0
# GOOD: pass all state as explicit function arguments
@jax.jit
def good_fn(x, weights):
    return x + weights
result = good_fn(jnp.ones(10), jnp.zeros(10)) # weights passed explicitly
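The fix for the counter example is to make the counter part of the function's inputs and outputs. Here is a minimal sketch of that pattern, with a hypothetical step function whose state is a dict pytree.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(state, x):
    # Pure: reads state from its argument and returns an updated copy
    new_state = {
        "count": state["count"] + 1,
        "running_sum": state["running_sum"] + jnp.sum(x),
    }
    return new_state, x * 2

state = {"count": jnp.array(0), "running_sum": jnp.array(0.0)}
for _ in range(2):
    state, out = step(state, jnp.array([1.0, 2.0]))

print(int(state["count"]))          # 2: incremented on every call, not only at trace time
print(float(state["running_sum"]))  # 6.0
```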
Python Side Effects During Tracing
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
    print("This runs ONCE during tracing")  # not at every call!
    y = x * 2
    print(f"y is: {y}")  # prints a tracer repr such as Traced<ShapedArray(float32[])>..., not a number
    return y
f(jnp.array(1.0))  # prints "This runs ONCE during tracing" and the tracer repr
f(jnp.array(2.0))  # prints NOTHING -- uses cached compilation
This is why Python print and logging calls inside jitted functions do not show runtime values. Use jax.debug.print for values that need to be printed during compiled execution:
@jax.jit
def f_with_debug(x):
    jax.debug.print("x value: {x}", x=x)  # prints concrete values at runtime
    return x ** 2
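Sometimes you need the runtime values on the Python side rather than printed. jax.debug.callback calls a host function with concrete arrays during compiled execution. The sketch below assumes a hypothetical record helper that appends to a list.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen_values = []  # hypothetical sink for values observed at runtime

def record(value):
    # Runs on the host during execution, receiving a concrete array
    seen_values.append(np.asarray(value))

@jax.jit
def f_with_callback(x):
    jax.debug.callback(record, x)
    return x ** 2

f_with_callback(jnp.array(1.0))
f_with_callback(jnp.array(2.0))  # cached compilation, but the callback still fires
jax.effects_barrier()  # wait until pending callbacks have run

print(sorted(float(v) for v in seen_values))  # [1.0, 2.0]
```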
7. JAX Random: PRNG Keys
JAX's random number generator is fundamentally different from NumPy's and PyTorch's -- it is stateless and functional.
import jax
import jax.numpy as jnp
# In NumPy: global mutable state
import numpy as np
np.random.seed(42)
a = np.random.normal() # advances internal state
b = np.random.normal() # different from a because state changed
# In JAX: explicit key threading -- no global state
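key = jax.random.PRNGKey(42)  # a deterministic key (a 2-element uint32 array)
# (newer code can use jax.random.key(42), which returns a typed key array)
# split produces new, independent subkeys. JAX never mutates the old key, but you
# must not reuse it for sampling after splitting; overwrite it, as done here.
key, subkey1, subkey2 = jax.random.split(key, num=3)
a = jax.random.normal(subkey1)  # signature: normal(key, shape=(), dtype=float32)
b = jax.random.normal(subkey2)  # independent of a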
b = jax.random.normal(subkey2) # independent of a
# The same key always produces the same value
x1 = jax.random.normal(jax.random.PRNGKey(0))
x2 = jax.random.normal(jax.random.PRNGKey(0))
print(x1 == x2) # True -- fully deterministic
# Why functional PRNG?
# 1. Reproducibility: passing the same key always gives the same result,
# regardless of execution order or parallelism.
# 2. Parallelism: vmap and pmap need independent randomness per element/device.
# Global state cannot be safely shared across parallel executions.
# 3. Composability: functions that take a key as input are pure.
def sample_batch(key: jax.Array, n: int) -> jnp.ndarray:
    """Reproducible batch sampling."""
    return jax.random.normal(key, shape=(n, 10))
# Typical training pattern: thread keys through the loop
key = jax.random.PRNGKey(0)
for step in range(100):
    key, subkey = jax.random.split(key)
    batch = sample_batch(subkey, n=64)
    # ... train step ...
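Point 2 above works like this in practice: split one key into one subkey per example, then vmap a single-example sampler over them. The sampler noisy_sample is hypothetical. The sketch also uses jax.random.fold_in, which derives a new key from an existing key and an integer (here a hypothetical step number 7).

```python
import jax
import jax.numpy as jnp

def noisy_sample(k):
    # Single-example sampler: three standard normals from its own key
    return jax.random.normal(k, (3,))

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)  # one independent subkey per example
samples = jax.vmap(noisy_sample)(keys)
print(samples.shape)  # (4, 3): each row drawn from a different subkey

# Same keys, same draws: the batch is reproducible
again = jax.vmap(noisy_sample)(jax.random.split(jax.random.PRNGKey(0), 4))
print(jnp.array_equal(samples, again))  # True

# fold_in: derive a key from (key, integer), e.g. a step counter, without splitting
step_key = jax.random.fold_in(key, 7)
print(jax.random.normal(step_key))
```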
8. JAX NumPy API Differences
JAX's jax.numpy is largely compatible with NumPy but has important differences:
import jax
import jax.numpy as jnp
import numpy as np
# 1. IMMUTABILITY: JAX arrays cannot be modified in place
x = jnp.array([1.0, 2.0, 3.0])
try:
    x[0] = 99.0  # TypeError: JAX arrays are immutable
except TypeError as e:
    print(e)
# Use .at[].set() for functional updates
x_new = x.at[0].set(99.0) # returns a new array; x is unchanged
print(x_new) # [99. 2. 3.]
print(x) # [ 1. 2. 3.] -- unchanged
# Other .at operations
x.at[1:].add(10) # returns new array with 10 added to indices 1:
x.at[0].max(50) # returns new array with max(x[0], 50)
# 2. DTYPE DEFAULTS: JAX defaults to 32-bit; NumPy defaults to 64-bit
np_arr = np.array([1.0, 2.0]) # dtype=float64
jnp_arr = jnp.array([1.0, 2.0]) # dtype=float32
# 3. OUT-OF-BOUNDS INDEXING: NumPy raises IndexError; JAX silently clips
a = jnp.array([1, 2, 3])
print(a[100]) # returns 3 -- last element, no error!
# This is a deliberate trade-off: compiled XLA code cannot raise an exception that
# depends on a runtime value, so JAX clamps retrievals instead. For debugging, use
# jax.debug.print or check indices in Python before calling the jitted function.
# 4. DEVICE ARRAYS: operations return device arrays, not CPU numpy arrays
result = jnp.sum(jnp_arr)
print(type(result))  # an ArrayImpl (a jax.Array); the exact module path varies by version
# Convert to numpy for plotting/sklearn
result_np = np.array(result) # triggers device->host transfer
# 5. CONTROL FLOW: Python if/while on traced values fails inside jitted functions,
# and Python for loops are unrolled at trace time. Use jax.lax primitives:
def scan_example(carry, x):
    """Accumulate: carry is the running sum, x is the current element."""
    carry = carry + x
    return carry, carry  # (new_carry, output)
init = 0.0
inputs = jnp.array([1.0, 2.0, 3.0, 4.0])
final, outputs = jax.lax.scan(scan_example, init, inputs)
print(outputs) # [1. 3. 6. 10.] -- cumulative sums
# lax.scan is the jittable equivalent of a for loop over a sequence
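The clamping rule in point 3 applies to retrieval. The sketch below shows two related behaviours documented for JAX's indexed access; check them against your JAX version. .at[...].get accepts a mode argument, and out-of-bounds writes are dropped rather than clamped.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])

print(a[100])                                     # 3: out-of-bounds retrieval is clamped
print(a.at[100].get(mode="fill", fill_value=-1))  # -1: ask for a fill value instead
print(a.at[100].set(99))                          # [1 2 3]: out-of-bounds update is dropped
print(a.at[-1].set(99))                           # negative indices still wrap: 1, 2, 99
```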
9. The Haiku/Flax/Optax/Equinox Ecosystem
Pure functions with explicit state are elegant but verbose. The neural network libraries solve this by providing structured ways to manage parameters.
Flax Example (NNX API, recommended in 2025)
import jax
import jax.numpy as jnp
import optax
from flax import nnx
# Define a model using Flax NNX -- closest API to PyTorch nn.Module
class MLP(nnx.Module):
    def __init__(self, in_features: int, hidden: int, out_features: int, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(in_features, hidden, rngs=rngs)
        self.fc2 = nnx.Linear(hidden, out_features, rngs=rngs)
        self.drop = nnx.Dropout(rate=0.2, rngs=rngs)
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = nnx.relu(self.fc1(x))
        x = self.drop(x)
        return self.fc2(x)
# Initialise model -- creates parameters stored inside the module
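model = MLP(10, 64, 4, rngs=nnx.Rngs(0))
# nnx.Optimizer wraps an Optax transformation and stores its state next to the model.
# API note (version-dependent): Flax 0.11+ expects nnx.Optimizer(model, tx, wrt=nnx.Param)
# and optimizer.update(model, grads); the forms used here match earlier NNX releases.
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate=1e-3))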
# Training step
@nnx.jit # JIT-compile the training step (NNX-aware jit)
def train_step(model, optimizer, X, y):
    def loss_fn(model):
        logits = model(X)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss
# Run a step
X = jnp.ones((32, 10))
y = jnp.zeros(32, dtype=jnp.int32)
loss = train_step(model, optimizer, X, y)
print(f"Loss: {loss:.4f}")
Optax Optimizers
import optax
import jax
import jax.numpy as jnp
# Optax optimizers are GradientTransformation objects: a pair of pure functions,
# init (builds the optimizer state) and update (turns gradients into updates)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # gradient clipping
    optax.adamw(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=1e-3,
            warmup_steps=100,
            decay_steps=1000,
        ),
        weight_decay=1e-4,
    ),
)
# Initialise optimizer state (holds momentum, second moment estimates, etc.)
params = {"W": jnp.ones((10, 4)), "b": jnp.zeros(4)}
opt_state = optimizer.init(params)
def update(params, opt_state, grads):
    updates, new_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state
# This init/update pattern is what Flax's optimizer wrapper calls internally,
# and what hand-written Haiku training loops use directly
10. When to Use JAX vs PyTorch
Choose JAX when:
- You need to differentiate through optimisation (meta-learning, implicit differentiation)
- You are targeting TPU hardware (JAX's XLA is the native TPU stack)
- You want to compose transformations freely (vmap(grad(f)) for per-sample gradients; hessian(f) for second-order methods). PyTorch's torch.func offers similar transforms, but in JAX they are the core programming model
- Your code has a mathematical flavour where functional purity improves reasoning
Choose PyTorch when:
- You are using HuggingFace transformers, peft, or trl, which are PyTorch-first
- Your team has an existing PyTorch codebase with custom CUDA kernels
- You need torchserve, torch.onnx, or torch.compile for production deployment
- You are doing reinforcement learning with environments that have Python-level side effects (JAX's pure function requirement is awkward here)
In practice: the frontier is blurring. torch.compile brings XLA-like compilation to PyTorch. JAX-based libraries (Flax NNX) are becoming more PyTorch-like. The mental models from both -- functional composition, explicit device management, computation graph awareness -- are useful regardless of which you use.
11. Practical JAX: A Complete Training Loop
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
# --- Pure-function model (no library) ---
def init_params(key: jax.Array, in_features: int, hidden: int, out_features: int) -> dict:
    """He (Kaiming) initialisation -- returns a pytree of parameters."""
    k1, k2 = jax.random.split(key)
    W1 = jax.random.normal(k1, (in_features, hidden)) * jnp.sqrt(2 / in_features)
    W2 = jax.random.normal(k2, (hidden, out_features)) * jnp.sqrt(2 / hidden)
    return {"W1": W1, "b1": jnp.zeros(hidden),
            "W2": W2, "b2": jnp.zeros(out_features)}
def predict(params: dict, X: jnp.ndarray) -> jnp.ndarray:
    """Forward pass: linear -> relu -> linear."""
    h = jax.nn.relu(X @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]
def cross_entropy_loss(params: dict, X: jnp.ndarray, y: jnp.ndarray) -> float:
    logits = predict(params, X)
    log_probs = logits - jax.nn.logsumexp(logits, axis=-1, keepdims=True)
    return -jnp.mean(jnp.sum(jax.nn.one_hot(y, logits.shape[-1]) * log_probs, axis=-1))
# JIT-compiled training step. The optax optimizer is a tuple of Python functions,
# not arrays, so it must be marked static (argument index 4).
@partial(jax.jit, static_argnums=(4,))
def train_step(params, opt_state, X, y, optimizer):
    loss, grads = jax.value_and_grad(cross_entropy_loss)(params, X, y)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss
# --- Setup ---
key = jax.random.PRNGKey(0)
params = init_params(key, in_features=20, hidden=64, out_features=5)
optimizer = optax.adamw(learning_rate=1e-3)
opt_state = optimizer.init(params)
# --- Synthetic data ---
rng = np.random.default_rng(42)
X_np = rng.normal(size=(500, 20)).astype(np.float32)
y_np = rng.integers(0, 5, size=500).astype(np.int32)
# --- Training loop ---
batch_size = 64
n_steps = 200
key = jax.random.PRNGKey(1)
for step in range(n_steps):
    key, subkey = jax.random.split(key)
    idx = jax.random.choice(subkey, 500, shape=(batch_size,), replace=False)
    X_b = jnp.array(X_np[np.array(idx)])
    y_b = jnp.array(y_np[np.array(idx)])
    params, opt_state, loss = train_step(params, opt_state, X_b, y_b, optimizer)
    if (step + 1) % 50 == 0:
        print(f"Step {step+1:3d} loss={loss:.4f}")
# --- Inference ---
X_test = jnp.array(X_np[:10])
logits = predict(params, X_test)
preds = jnp.argmax(logits, axis=-1)
print(f"Predictions: {preds}")
Key Takeaways
- JAX traces functions once with abstract tracers to build an XLA computation graph, then compiles and caches it. Python code inside jitted functions runs at trace time, not at call time.
- Pure functions are mandatory. Side effects, mutable closures, and Python control flow on traced values break jit/grad/vmap. Use lax.cond, lax.while_loop, lax.scan for jittable control flow, and jax.debug.print for runtime values.
- jax.grad returns a function; it does not compute a gradient immediately. Compose it freely: jax.grad(jax.grad(f)) gives second derivatives of scalar functions (use jax.hessian for vector inputs), and jax.vmap(jax.grad(f)) gives per-sample gradients.
- vmap eliminates batch indices from your thinking. Write functions for a single example, then vmap them. The compiled result is usually comparable in speed to hand-written batched code.
- JAX PRNG requires explicit key threading. jax.random.split produces independent subkeys. The same key always produces the same value, so randomness is fully deterministic and reproducible.
- JAX arrays are immutable. Use .at[].set() for functional updates. Out-of-bounds indexing silently clips rather than raising an error.
- Haiku, Flax NNX, and Equinox are the main neural network libraries. Haiku returns parameters as an explicit pytree that you manage. Equinox models are themselves pytrees. Flax NNX keeps parameters inside module objects, PyTorch-style, and converts to pytrees at transformation boundaries.
- Optax is the standard JAX optimizer library. Every optimizer is a GradientTransformation with pure init and update functions. Chain them with optax.chain for gradient clipping + adaptive optimisation + learning rate schedules.
- JAX excels at research requiring higher-order derivatives, per-sample gradients, TPU targeting, or functional composition. PyTorch excels for production deployment, the HuggingFace ecosystem, and teams with existing infrastructure.
Practice Problems
Problem 1 -- JIT Tracing Behaviour
Write a function count_calls(x) that increments a global counter and returns x * 2. JIT-compile it. Call it 5 times with the same shape, then call it once with a different shape. Print the counter. Explain what you observe. Then rewrite the function to use jax.debug.print to print the counter at runtime and verify the difference.
Problem 2 -- Custom Gradient
Implement the straight-through estimator (STE) for a quantisation function using jax.custom_vjp. The forward pass rounds x to the nearest integer; the backward pass passes gradients through unchanged (as if the round was an identity). Test that gradients flow correctly through the quantised layer.
Problem 3 -- vmap for Per-Sample Gradients
Given a model f(params, x) and a dataset of 100 samples, compute the gradient of the loss with respect to params for each individual sample (not the mean gradient -- per-sample). Do this using jax.vmap(jax.grad(loss_single)). Verify that the mean of per-sample gradients equals the gradient of the mean loss. Profile both approaches and compare their runtime.
Problem 4 -- MAML Inner Loop
Implement one step of Model-Agnostic Meta-Learning (MAML). Given support data (X_s, y_s), compute the gradient of the loss and take one gradient step to produce adapted parameters. Then evaluate the adapted parameters on query data (X_q, y_q). The outer loss gradient must backpropagate through the inner gradient step. This requires jax.grad applied to a function that itself calls jax.grad. Verify that the outer gradient is finite (jax.grad always returns arrays, never None).
Problem 5 -- Benchmarking vmap vs Manual Batch
Write the same function in two ways: (a) with explicit batch indexing using standard matrix operations, and (b) as a single-example function vmapped over the batch axis. Benchmark both on batch sizes 32, 256, 2048, and 16384. Plot throughput (samples/second) vs batch size. Investigate whether and when the two approaches differ in speed.
